import gym
import numpy as np
from copy import deepcopy as cp
import torch
from train_victim_wm import WorldModel, MLP
from mnp_attack import get_mnp_attacked_observation

class VictimEnvWrapper(gym.Wrapper):
    """Gym wrapper that shows the victim adversarially altered observations.

    Every observation produced by the wrapped environment is routed through
    the adversary stored in ``model_adv`` before it is returned to the
    caller. When ``use_worldmodel_to_shutoff`` is set, each step also asks
    ``worldmodel`` whether the observed transition looks attacked and, if
    so, reports the episode as done.

    ``model_adv`` is initialised to ``None`` and must be assigned by the
    caller before ``reset`` or ``step`` is used.
    """

    def __init__(self, victim_env):
        """Wrap ``victim_env`` and set up the attack/defence state."""
        super().__init__(victim_env)
        # Adversary policy. It must expose ``predict(...)`` and
        # ``env.envs[0]`` with ``get_obs_adversary`` / ``get_adversarial_obs``
        # (presumably a Stable-Baselines model on a vectorised env -- confirm).
        self.model_adv = None
        # Observation most recently returned by ``step``; ``None`` until the
        # first ``step`` of an episode has completed.
        self.last_obs_seen = None
        # Forwarded as ``deterministic`` to ``model_adv.predict``.
        self.other_is_deterministic = False
        self.worldmodel = WorldModel()
        # Names of the extra keys that ``step`` writes into ``infos``.
        self.additional_metrics = ["attack_detected", "victim_rewards_wm_defended"]
        # When True, ``step`` consults the world model and may set ``done``.
        self.use_worldmodel_to_shutoff = False
        # Not read anywhere inside this class.
        self.run_mnp_attack = False

    def reset(self, **kwargs):
        """Reset the wrapped env and return the first (attacked) observation.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset``, matching the ``gym.Wrapper.reset`` signature.

        Returns:
            The initial observation after the adversary has altered it.
        """
        self.last_obs_seen = None
        self.worldmodel.reset()

        obs = self.env.reset(**kwargs)
        # NOTE(review): ``last_obs_seen`` is intentionally left untouched
        # here, so it is still ``None`` during the first ``step`` call, both
        # for the adversary and for the world-model check -- confirm that
        # this is the intended behaviour.
        return self.get_obs(obs, action=None)

    def get_obs(self, obs, action):
        """Return the observation the victim sees for true observation ``obs``.

        Args:
            obs: True observation from the wrapped environment.
            action: The victim action that led to ``obs``; ``None`` on reset.

        Returns:
            The observation produced by the adversary environment's
            ``get_adversarial_obs``.
        """
        adv_env = self.model_adv.env.envs[0]
        obs_adv = adv_env.get_obs_adversary(
            obs_true=obs,
            last_obs_seen=self.last_obs_seen,
            last_action=action,
        )
        # ``predict`` returns a tuple; only its first element (the action)
        # is used.
        action_adv = self.model_adv.predict(
            obs_adv,
            state=None,
            deterministic=self.other_is_deterministic,
        )[0]
        return adv_env.get_adversarial_obs(
            action_adv,
            obs,
            last_obs_seen=self.last_obs_seen,
            last_victim_action=action,
        )

    def step(self, action):
        """Step the wrapped env and return the attacked observation.

        Args:
            action: Victim action passed to the wrapped environment.

        Returns:
            Tuple ``(obs_seen, reward, done, infos)``. ``infos`` additionally
            carries ``"attack_detected"`` and ``"victim_rewards_wm_defended"``;
            both are ``-1`` when the world-model defence is disabled.
        """
        obs, reward, done, infos = self.env.step(action)
        obs_seen = self.get_obs(obs, action)

        if self.use_worldmodel_to_shutoff:
            attacked = bool(
                self.worldmodel.am_i_attacked(self.last_obs_seen, action, obs_seen)
            )
            if attacked:
                # This only changes the flag handed back to the caller; the
                # wrapped env itself is not reset or closed here, so ending
                # the episode relies on the caller honouring ``done``.
                done = True
            infos["attack_detected"] = attacked
            infos["victim_rewards_wm_defended"] = reward
        else:
            # Sentinel values: the defence was not evaluated on this step.
            infos["attack_detected"] = -1
            infos["victim_rewards_wm_defended"] = -1

        self.last_obs_seen = obs_seen
        return obs_seen, reward, done, infos
